#import matplotlib.pyplot as plt
import numpy as np
from algo.feature.fill_nulldata import liner_interpolate
from utils.format_util import dup_name_handler
import logging
import math

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")


def moving_average(data, window_size):
	"""Smooth *data* with a uniform (boxcar) moving average.

	Args:
		data: 1-D sequence of numbers.
		window_size: Width of the averaging window. The kernel has
			``int(window_size)`` taps, each weighted ``1 / window_size``.

	Returns:
		numpy.ndarray produced by ``np.convolve`` in ``'same'`` mode, so the
		samples near both ends are averaged against implicit zeros.
	"""
	tap_count = int(window_size)
	kernel = np.ones(tap_count) / float(window_size)
	smoothed = np.convolve(data, kernel, mode='same')
	return smoothed


def exp_ma(data, beta, reset_threshold=0.5):
	"""Bias-corrected exponential moving average with restart on upward jumps.

	The filter state follows ``s[i] = beta * s[i-1] + (1 - beta) * data[i]``
	and each output is divided by ``1 - beta ** k`` where ``k`` counts the
	samples seen since the last (re)start. Whenever a sample exceeds its
	predecessor by more than ``reset_threshold`` the state is cleared, so the
	output snaps to the new level instead of lagging behind it.

	Args:
		data: 1-D indexable sequence of numbers.
		beta: Decay factor; must satisfy ``0 <= beta < 1``.
		reset_threshold: Minimum upward step ``data[i] - data[i-1]`` that
			restarts the filter. Defaults to 0.5, the value that used to be
			hard-coded.

	Returns:
		numpy.ndarray with one smoothed value per input sample.

	Raises:
		ValueError: If ``beta`` is outside ``[0, 1)``; ``beta == 1`` would
			make the bias-correction denominator zero.
	"""
	beta = float(beta)
	if not 0.0 <= beta < 1.0:
		raise ValueError("beta must satisfy 0 <= beta < 1")

	n = len(data)
	# filtered[0] is the zero initial state; filtered[i + 1] belongs to data[i].
	filtered = np.zeros(n + 1)
	corrected = np.zeros(n)
	restart_idx = 0

	for i in range(n):
		# The first sample has no predecessor, so it can never trigger a
		# restart (and must not be compared against data[-1]).
		if i > 0 and (data[i] - data[i - 1]) > reset_threshold:
			filtered[i + 1] = (1 - beta) * data[i]
			restart_idx = i
		else:
			filtered[i + 1] = beta * filtered[i] + (1 - beta) * data[i]
		# Bias correction relative to the most recent restart.
		corrected[i] = filtered[i + 1] / (1 - beta ** (i + 1 - restart_idx))

	return corrected


"""
* 创建系数矩阵X
* size - 2×size+1 = window_size
* rank - 拟合多项式阶次
* x - 创建的系数矩阵
"""


def create_x(size, rank):
	"""Build the polynomial design matrix for a symmetric window.

	Row ``i`` corresponds to the offset ``m = i - size`` (so offsets run from
	``-size`` to ``size``) and holds ``[m ** 0, m ** 1, ..., m ** (rank - 1)]``.

	Args:
		size: Half window width; the matrix has ``2 * size + 1`` rows.
		rank: Number of columns, i.e. number of polynomial terms.

	Returns:
		numpy.matrix of shape ``(2 * size + 1, rank)``.
	"""
	rows = [[offset ** power for power in range(rank)] for offset in range(-size, size + 1)]
	# np.mat was removed from the main namespace in NumPy 2.0; np.asmatrix is
	# the equivalent spelling available on both NumPy 1.x and 2.x.
	return np.asmatrix(rows)


def savitzky_golay(data, window_size, rank):
	"""Smooth *data* with a Savitzky-Golay (local least-squares polynomial) filter.

	Each output sample is the value at the window centre of the polynomial
	with ``rank`` terms (powers ``0 .. rank - 1``) fitted by least squares to
	the ``window_size`` samples around it.

	Edge handling: the series is extended on both sides by repeating the
	first / last sample ``(window_size - 1) // 2`` times, so every sample
	gets a full window and the result has exactly ``len(data)`` entries.
	Earlier revisions called ``np.insert`` for this but discarded its return
	value, and then patched the edges with a zero-padded moving average that
	attenuated and shifted the first and last samples.

	Args:
		data: 1-D sequence of numbers (assumed one-dimensional).
		window_size: Positive odd window width.
		rank: Number of polynomial terms, ``1 <= rank <= window_size``.

	Returns:
		list of floats, same length as ``data``.

	Raises:
		ValueError: If ``window_size`` is not a positive odd integer or
			``rank`` is outside ``[1, window_size]``.
	"""
	window_size = int(window_size)
	rank = int(rank)
	if window_size < 1 or window_size % 2 == 0:
		raise ValueError("window_size must be a positive odd integer")
	if not 1 <= rank <= window_size:
		raise ValueError("rank must satisfy 1 <= rank <= window_size")

	samples = np.asarray(data, dtype=float)
	if samples.size == 0:
		return []

	half = (window_size - 1) // 2
	# Repeat the boundary samples so the first and last points have full windows.
	padded = np.pad(samples, half, mode='edge')

	# Design matrix: one row per offset -half..half, columns are offset ** k.
	offsets = np.arange(-half, half + 1, dtype=float)
	design = np.vander(offsets, rank, increasing=True)
	# The fitted value at offset 0 is the constant coefficient, i.e. the first
	# row of the pseudo-inverse applied to the window. This equals row `half`
	# of X (X^T X)^-1 X^T.
	centre_coeffs = np.linalg.pinv(design)[0]

	# Sliding dot product of every window with the centre coefficients.
	smoothed = np.correlate(padded, centre_coeffs, mode='valid')
	return smoothed.tolist()


def oddCheck(window_size):
	"""Return *window_size* as an odd integer.

	The argument is converted with ``int()``; an even result is bumped up by
	one, an odd result is returned unchanged.
	"""
	size = int(window_size)
	return size if size % 2 != 0 else size + 1


def run(df, params):
	"""Smooth one column of *df* and insert the result right after it.

	Args:
		df: DataFrame holding the column to smooth.
		params: Mapping with the keys
			``"col"`` - name of the column to smooth;
			``"method"`` - ``'ma'``, ``'ema'`` or ``'sg'``;
			``"ma_window_size"`` (for ``'ma'``);
			``"ema_decay"`` (for ``'ema'``);
			``"sg_window_size"`` and ``"sg_rank"`` (for ``'sg'``).

	Returns:
		Tuple ``(df, results)``. On success ``df`` is a re-indexed frame with
		the new column ``"<col>_<method>"`` (de-duplicated through
		``dup_name_handler``) placed directly after ``col`` and
		``results["status"]`` is ``"SUCCESS"``. On failure, including an
		unknown method, the input frame is returned unchanged,
		``results["status"]`` is ``"FAIL"`` and ``results["error_msg"]``
		holds the message as a string.

	Raises:
		KeyError: If ``col`` is not a column of ``df`` (raised by the column
			lookup, before any smoothing is attempted).
	"""
	results = {"status": "SUCCESS"}
	col = params.get("col")
	method = params.get("method")
	# NOTE: null interpolation (liner_interpolate) is currently disabled here.

	data = df[col].values

	try:
		if method == 'ma':
			new_data = moving_average(data, oddCheck(params.get("ma_window_size")))
		elif method == 'ema':
			new_data = exp_ma(data, float(params.get("ema_decay")))
		elif method == 'sg':
			new_data = savitzky_golay(data, oddCheck(int(params.get("sg_window_size"))), int(params.get("sg_rank")))
		else:
			# An unknown method is a failure; report it as one instead of
			# leaving the status at "SUCCESS".
			results["status"] = "FAIL"
			results["error_msg"] = "wrong method"
			return df, results

		logging.info("smoothing algorithm finished")

		all_cols = df.columns.tolist()
		new_col = dup_name_handler('_'.join([col, method]), all_cols)

		# Place the smoothed column immediately after its source column.
		# reindex builds a new frame, so the caller's frame is not mutated.
		all_cols.insert(all_cols.index(col) + 1, new_col)
		df = df.reindex(columns=all_cols)
		df[new_col] = new_data
	except Exception as e:
		# Store the message as text so the result dict stays serialisable.
		results["error_msg"] = str(e)
		results["status"] = "FAIL"
		logging.exception(e)

	# Returning here rather than from a `finally` clause lets KeyboardInterrupt
	# and SystemExit propagate instead of being silently swallowed.
	return df, results

# t = np.linspace(start=-4, stop=4, num=100)
# y = np.sin(t) + np.random.randn(len(t)) * 0.1
#
# d = {"time":t,"data":y}
# import pandas as pd
# df = pd.DataFrame(data=d)
# df.to_csv("test_data.csv")


# window_size_test = oddCheck(30)
#
# rank_test = 3
# ysg = savitzky_golay(y, window_size_test, rank_test)
# plt.plot(t, y, "b.-", t, ysg, "r.-")
# plt.plot(t, y)
# plt.xlabel('Time')
# plt.ylabel('Value')
# plt.legend(['original data', 'smooth data'])
# plt.grid(True)
# plt.title("sg")
# plt.show()
#
# ysg = moving_average(y, window_size_test)
# plt.plot(t, y, "b.-", t, ysg, "r.-")
# plt.plot(t, y)
# plt.xlabel('Time')
# plt.ylabel('Value')
# plt.legend(['original data', 'smooth data'])
# plt.grid(True)
# plt.title("ma")
# plt.show()
#
# ysg = exp_ma(y, 0.2)
# plt.plot(t, y, "b.-", t, ysg, "r.-")
# plt.plot(t, y)
# plt.xlabel('Time')
# plt.ylabel('Value')
# plt.legend(['original data', 'smooth data'])
# plt.grid(True)
# plt.title("ema")
# plt.show()
